import os
import sys
import numpy as np
import torch
import torch.nn as nn

from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image

# Current user's home directory; the dataset path further down is resolved relative to it.
home_dir = os.path.expanduser('~')


class INEX(Dataset):
    """In-memory dataset of the ``.png`` images found one level below ``root``.

    Every sub-directory of ``root`` is scanned (in sorted order) for ``.png``
    files; files whose name contains ``'all'`` are skipped.  The file name is
    split on ``'_'``:

    * field 0 is a tag, ``'o'`` or ``'m'``;
    * field 2 is parsed as the integer label ``c1``;
    * for ``'m'`` files field 3 is parsed as the integer ``l``
      (``0`` is stored for ``'o'`` files).

    Files with any other tag are skipped so that ``data``, ``om``, ``c1`` and
    ``l`` always have the same length and stay index-aligned.  All images are
    decoded to RGB up front and kept in memory.

    Args:
        root: directory containing the image sub-directories.
        transform: optional callable applied to each image in ``__getitem__``.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.data = []            # decoded RGB PIL images
        self.c1, self.l = [], []  # per-sample integers parsed from the file name
        self.om = []              # per-sample tag: 'o' or 'm'

        for subdir in sorted(os.listdir(root)):
            subdir_path = os.path.join(root, subdir)
            if not os.path.isdir(subdir_path):
                continue
            for filename in sorted(os.listdir(subdir_path)):
                if not filename.endswith('.png') or 'all' in filename:
                    continue

                img_info = filename.split('_')
                tag = img_info[0]
                if tag == 'o':
                    c1, l = int(img_info[2]), 0
                elif tag == 'm':
                    c1, l = int(img_info[2]), int(img_info[3])
                else:
                    # Unknown tag: skip the file entirely so that the image
                    # and tag lists stay aligned with the label lists.
                    continue

                filepath = os.path.join(subdir_path, filename)
                # The context manager closes the underlying file; convert()
                # returns a fully loaded image that no longer needs it.
                with Image.open(filepath) as img:
                    image = img.convert('RGB')

                self.data.append(image)
                self.om.append(tag)
                self.c1.append(c1)
                self.l.append(l)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(tag, image, c1, l)`` for sample ``idx``, applying ``transform`` to the image."""
        image = self.data[idx]
        if self.transform is not None:
            image = self.transform(image)
        return self.om[idx], image, self.c1[idx], self.l[idx]


# Root of the INEX evaluation set; INEX scans its sub-directories for .png files.
data_path = os.path.join(home_dir, 'data/inex/cifar10_ns_135pair_norm_logsp_128/')
# Entries directly under data_path (not used further in this module's visible code).
set_dir = os.listdir(data_path)

# Resize to 32x32, convert to a tensor, then normalize each channel as (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

ds = INEX(root=data_path, transform=transform)
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=128,
    shuffle=False,  # iterate in dataset order
    num_workers=0,
)


@torch.no_grad()
def check_inex(model, loader=None, device="cuda"):
    """Print INEX uncertainty summaries of ``model`` and return them.

    For every batch ``(om, x, c1, l)`` the model's logits are converted to
    class probabilities, from which the following are collected:

    * the max probability;
    * the probability of label ``c1``;
    * the predictive entropy.

    Max probability and entropy are then averaged per ``l`` value over the
    mixed (``'m'``) samples. The probability of ``c1`` is averaged over the
    non-mixed samples.

    NOTE(review): the model's train/eval mode is not changed here; callers
    presumably switch it to eval mode beforehand — confirm.

    Args:
        model: callable mapping an image batch to logits of shape
            (batch, num_classes).
        loader: iterable of ``(om, x, c1, l)`` batches; defaults to the
            module-level ``dl`` DataLoader.
        device: device the images and ``c1`` indices are moved to for the
            forward pass (default ``"cuda"``).

    Returns:
        ``(probs_max_means, entropies)`` as lists of floats. Index 0 holds the
        non-mixed summary entry. The remaining entries follow the sorted
        unique ``l`` values of the mixed samples. The last entry of each list
        is printed.
    """
    if loader is None:
        loader = dl

    oms, ls, probs_c1, probs_max, entropy = [], [], [], [], []
    for om, x_, c1, l in loader:
        logits = model(x_.to(device))
        # log_softmax keeps log-probabilities finite even when a probability
        # underflows to 0, so p * log(p) is 0 there instead of 0 * -inf = nan.
        log_probs = torch.log_softmax(logits, dim=1)
        probs = log_probs.exp()

        oms.extend(om)
        ls.append(torch.as_tensor(l).cpu())
        probs_max.append(probs.max(dim=1)[0].cpu())
        c1_idx = torch.as_tensor(c1).long().unsqueeze(1).to(probs.device)
        # squeeze(1), not squeeze(): a batch of size 1 must stay 1-D so that
        # torch.cat below does not fail on a zero-dimensional tensor.
        probs_c1.append(torch.gather(probs, 1, c1_idx).squeeze(1).cpu())
        entropy.append(-(probs * log_probs).sum(dim=1).cpu())

    # Everything lives on the CPU from here on, matching the boolean masks.
    probs_c1 = torch.cat(probs_c1)
    probs_max = torch.cat(probs_max)
    entropy = torch.cat(entropy)
    ls = torch.cat(ls)
    is_mix = torch.tensor([om == 'm' for om in oms], dtype=torch.bool)

    probs_max_means = []
    entropies = []
    # Only l values that occur on mixed samples; any other value would select
    # an empty subset and contribute a nan mean.
    for lambda_ in ls[is_mix].unique():
        selected_index = (ls == lambda_) & is_mix
        probs_max_means.append(probs_max[selected_index].mean().item())
        entropies.append(entropy[selected_index].mean().item())

    # Non-mixed summary entry: mean probability of c1 over the 'o' samples.
    oratio = probs_c1[~is_mix].mean()
    probs_max_means.insert(0, oratio.item())
    # NOTE(review): this entry selects l == 1 rather than the non-mixed
    # samples used for oratio above — confirm which subset is intended.
    entropies.insert(0, entropy[ls == 1].mean().item())

    print(f"Uncertainty Max Probability means: {probs_max_means[-1]}")
    print(f"Uncertainty Entropies: {entropies[-1]}")
    return probs_max_means, entropies
